# coding=utf-8
# Copyright 2021 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of multiheaded FAVOR-attention & FAVOR-self-attention layers.

Prefix Sum TensorFlow implementation by Valerii Likhosherstov.
"""
import math
import numpy as np
import tensorflow as tf
import util

BIG_CONSTANT = 1e8


def create_projection_matrix(m, d, seed=0, scaling=0, struct_mode=False):
    r"""Constructs the matrix of random projections.

    Constructs a matrix of random orthogonal projections. Each projection vector
    has direction chosen uniformly at random and either deterministic length
    \sqrt{d} or length taken from the \chi(d) distribution (in the latter case
    marginal distributions of the projections are d-dimensional Gaussian vectors
    with associated identity covariance matrix).

    Args:
      m: number of random projections.
      d: dimensionality of each random projection.
      seed: random seed used to construct projections.
      scaling: 1 if all the random projections need to be renormalized to have
        length \sqrt{d}, 0 if the lengths of random projections should follow
        \chi(d) distribution.
      struct_mode: if True, products of Givens rotations will be used to
        construct the random orthogonal matrix, bypassing Gram-Schmidt
        orthogonalization.

    Returns:
      The matrix of random projections of the shape [m, d].
    """
    nb_full_blocks = int(m / d)
    block_list = []
    current_seed = seed
    for _ in range(nb_full_blocks):
        if struct_mode:
            q = create_products_of_givens_rotations(d, current_seed)
        else:
            unstructured_block = tf.random.normal((d, d), seed=current_seed)
            q, _ = tf.linalg.qr(unstructured_block)
            q = tf.transpose(q)
        block_list.append(q)
        current_seed += 1
    remaining_rows = m - nb_full_blocks * d
    if remaining_rows > 0:
        if struct_mode:
            q = create_products_of_givens_rotations(d, current_seed)
        else:
            unstructured_block = tf.random.normal((d, d), seed=current_seed)
            q, _ = tf.linalg.qr(unstructured_block)
            q = tf.transpose(q)
        block_list.append(q[0:remaining_rows])
    final_matrix = tf.experimental.numpy.vstack(block_list)
    current_seed += 1

    if scaling == 0:
        multiplier = tf.norm(tf.random.normal((m, d), seed=current_seed), axis=1)
    elif scaling == 1:
        multiplier = tf.math.sqrt(float(d)) * tf.ones((m,))
    else:
        raise ValueError("Scaling must be one of {0, 1}. Was %s" % scaling)

    return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)
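

# Example sketch (illustrative dimensions, assuming eager execution):
#   projection_matrix = create_projection_matrix(m=128, d=64, seed=0)
#   # projection_matrix has shape [128, 64]; the rows of each 64 x 64 block are
#   # pairwise orthogonal, and with scaling=0 their lengths follow chi(64).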


def create_products_of_givens_rotations(dim, seed):
    r"""Constructs a 2D-tensor which is a product of Givens random rotations.

    Constructs a 2D-tensor of the form G_1 * ... * G_k, where G_i is a Givens
    random rotation. The resulting tensor mimics a matrix taken uniformly at
    random from the orthogonal group.

    Args:
      dim: number of rows/columns of the resulting 2D-tensor.
      seed: random seed.

    Returns:
      The product of Givens random rotations.
    """
    nb_givens_rotations = dim * int(math.ceil(math.log(float(dim))))
    q = np.eye(dim, dim)
    np.random.seed(seed)
    for _ in range(nb_givens_rotations):
        random_angle = math.pi * np.random.uniform()
        random_indices = np.random.choice(dim, 2)
        index_i = min(random_indices[0], random_indices[1])
        index_j = max(random_indices[0], random_indices[1])
        slice_i = q[index_i]
        slice_j = q[index_j]
        new_slice_i = math.cos(random_angle) * slice_i + math.sin(
            random_angle) * slice_j
        new_slice_j = -math.sin(random_angle) * slice_i + math.cos(
            random_angle) * slice_j
        q[index_i] = new_slice_i
        q[index_j] = new_slice_j
    return tf.cast(tf.constant(q), dtype=tf.float32)
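

# Example sketch (illustrative only): the returned matrix is (up to floating
# point error) orthogonal, so multiplying it by its transpose is close to the
# identity.
#   q = create_products_of_givens_rotations(dim=16, seed=0)
#   identity_approx = tf.matmul(q, q, transpose_b=True)  # roughly tf.eye(16)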


def relu_kernel_transformation(data,
                               is_query,
                               projection_matrix=None,
                               numerical_stabilizer=0.001):
    """Computes features for the ReLU-kernel.

    Computes random features for the ReLU kernel from
    https://arxiv.org/pdf/2009.14794.pdf.

    Args:
      data: input data tensor of the shape [B, L, H, D], where: B - batch
        dimension, L - sequence length, H - heads, D - features.
      is_query: indicates whether the input data is a query or a key tensor.
      projection_matrix: random Gaussian matrix of shape [M, D], where M stands
        for the number of random features and each D x D sub-block has pairwise
        orthogonal rows.
      numerical_stabilizer: small positive constant for numerical stability.

    Returns:
      Corresponding kernel feature map.
    """
    del is_query
    if projection_matrix is None:
        return tf.nn.relu(data) + numerical_stabilizer
    else:
        ratio = 1.0 / tf.math.sqrt(
            tf.dtypes.cast(projection_matrix.shape[0], tf.float32))
        data_dash = ratio * tf.einsum("blhd,md->blhm", data, projection_matrix)
        return tf.nn.relu(data_dash) + numerical_stabilizer
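

# Example sketch (shapes below are illustrative assumptions):
#   data = tf.random.normal([2, 10, 4, 64])             # [B, L, H, D]
#   projection = create_projection_matrix(m=128, d=64)  # [M, D]
#   features = relu_kernel_transformation(data, True, projection)
#   # features has shape [2, 10, 4, 128], i.e. [B, L, H, M].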


def softmax_kernel_transformation(data,
                                  is_query,
                                  projection_matrix=None,
                                  numerical_stabilizer=0.000001):
    """Computes random features for the softmax kernel using FAVOR+ mechanism.

    Computes random features for the softmax kernel using FAVOR+ mechanism from
    https://arxiv.org/pdf/2009.14794.pdf.

    Args:
      data: input data tensor of the shape [B, L, H, D], where: B - batch
        dimension, L - sequence length, H - heads, D - features.
      is_query: indicates whether the input data is a query or a key tensor.
      projection_matrix: random Gaussian matrix of shape [M, D], where M stands
        for the number of random features and each D x D sub-block has pairwise
        orthogonal rows. Must not be None for this transformation.
      numerical_stabilizer: small positive constant for numerical stability.

    Returns:
      Corresponding kernel feature map.
    """
    data_normalizer = 1.0 / (
        tf.math.sqrt(tf.math.sqrt(tf.dtypes.cast(data.shape[-1], tf.float32))))
    data = data_normalizer * data
    ratio = 1.0 / tf.math.sqrt(
        tf.dtypes.cast(projection_matrix.shape[0], tf.float32))
    data_dash = tf.einsum("blhd,md->blhm", data, projection_matrix)
    diag_data = tf.math.square(data)
    diag_data = tf.math.reduce_sum(
        diag_data, axis=tf.keras.backend.ndim(data) - 1)
    diag_data = diag_data / 2.0
    diag_data = tf.expand_dims(diag_data, axis=tf.keras.backend.ndim(data) - 1)
    last_dims_t = (len(data_dash.shape) - 1,)
    attention_dims_t = (len(data_dash.shape) - 3,)
    if is_query:
        data_dash = ratio * (
            tf.math.exp(data_dash - diag_data - tf.math.reduce_max(
                data_dash, axis=last_dims_t, keepdims=True)) + numerical_stabilizer)
    else:
        data_dash = ratio * (
            tf.math.exp(data_dash - diag_data - tf.math.reduce_max(
                data_dash, axis=last_dims_t + attention_dims_t, keepdims=True)) +
            numerical_stabilizer)

    return data_dash
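

# Example sketch (illustrative; `query` and `key` are assumed [B, L, H, D]
# tensors and `projection` a matrix from create_projection_matrix): query and
# key feature maps use the same projection; the stabilizing max is taken per
# position for queries and additionally across positions for keys.
#   q_prime = softmax_kernel_transformation(query, True, projection)  # [B, L, H, M]
#   k_prime = softmax_kernel_transformation(key, False, projection)   # [B, L, H, M]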


def noncausal_numerator(qs, ks, vs):
    """Computes not-normalized FAVOR noncausal attention AV.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].
      vs: value tensor of the shape [L,B,H,D].

    Returns:
      Unnormalized FAVOR noncausal attention AV.
    """
    kvs = tf.einsum("lbhm,lbhd->bhmd", ks, vs)
    return tf.einsum("lbhm,bhmd->lbhd", qs, kvs)


def noncausal_denominator(qs, ks):
    """Computes FAVOR normalizer in noncausal attention.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].

    Returns:
      FAVOR normalizer in noncausal attention.
    """
    all_ones = tf.ones([ks.shape[0]])
    ks_sum = tf.einsum("lbhm,l->bhm", ks, all_ones)
    return tf.einsum("lbhm,bhm->lbh", qs, ks_sum)
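

# Example sketch: bidirectional FAVOR attention from kernel features in O(L)
# time. The inputs are assumed to already be in the [L, B, H, *] layout
# produced inside favor_attention below.
#   av = noncausal_numerator(query_prime, key_prime, value)     # [L, B, H, D]
#   normalizer = noncausal_denominator(query_prime, key_prime)  # [L, B, H]
#   attention = av / tf.expand_dims(normalizer, axis=-1)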


@tf.custom_gradient
def causal_numerator(qs, ks, vs):
    """Computes not-normalized FAVOR causal attention A_{masked}V.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].
      vs: value tensor of the shape [L,B,H,D].

    Returns:
      Unnormalized FAVOR causal attention A_{masked}V.
    """

    result = []
    sums = tf.zeros_like(tf.einsum("ijk,ijl->ijkl", ks[0], vs[0]))

    for index in range(qs.shape[0]):
        sums = sums + tf.einsum("ijk,ijl->ijkl", ks[index], vs[index])
        result.append(tf.einsum("ijkl,ijk->ijl", sums, qs[index])[None, Ellipsis])

    result = tf.concat(result, axis=0)

    def grad(res_grad):

        grads = tf.zeros_like(tf.einsum("ijk,ijl->ijkl", ks[0], vs[0]))

        gr_sums = sums

        q_grads = []
        k_grads = []
        v_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):

            q_grads.append(
                tf.einsum("ijkl,ijl->ijk", gr_sums, res_grad[index])[None, Ellipsis])
            grads = grads + tf.einsum("ijk,ijl->ijkl", qs[index], res_grad[index])
            k_grads.append(tf.einsum("ijkl,ijl->ijk", grads, vs[index])[None, Ellipsis])
            v_grads.append(tf.einsum("ijkl,ijk->ijl", grads, ks[index])[None, Ellipsis])
            gr_sums = gr_sums - tf.einsum("ijk,ijl->ijkl", ks[index], vs[index])

        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)
        v_grads = tf.concat(v_grads[::-1], axis=0)

        return q_grads, k_grads, v_grads

    return result, grad


@tf.custom_gradient
def causal_denominator(qs, ks):
    """Computes FAVOR normalizer in causal attention.

    Args:
      qs: query_prime tensor of the shape [L,B,H,M].
      ks: key_prime tensor of the shape [L,B,H,M].

    Returns:
      FAVOR normalizer in causal attention.
    """

    result = []
    sums = tf.zeros_like(ks[0])

    for index in range(qs.shape[0]):
        sums = sums + ks[index]
        result.append(tf.reduce_sum(qs[index] * sums, axis=2)[None, Ellipsis])

    result = tf.concat(result, axis=0)

    def grad(res_grad):

        k_grad = tf.zeros_like(ks[0])

        gr_sums = sums

        q_grads = []
        k_grads = []

        for index in range(qs.shape[0] - 1, -1, -1):

            q_grads.append(
                tf.einsum("ijk,ij->ijk", gr_sums, res_grad[index])[None, Ellipsis])
            k_grad = k_grad + tf.einsum("ijk,ij->ijk", qs[index], res_grad[index])
            k_grads.append(k_grad[None, Ellipsis])
            gr_sums = gr_sums - ks[index]

        q_grads = tf.concat(q_grads[::-1], axis=0)
        k_grads = tf.concat(k_grads[::-1], axis=0)

        return q_grads, k_grads

    return result, grad
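

# Example sketch: unidirectional (causal) FAVOR attention via prefix sums; the
# custom gradients above keep memory linear in the sequence length. Inputs are
# assumed to be in the [L, B, H, *] layout produced inside favor_attention.
#   av = causal_numerator(query_prime, key_prime, value)     # [L, B, H, D]
#   normalizer = causal_denominator(query_prime, key_prime)  # [L, B, H]
#   attention = av / tf.expand_dims(normalizer, axis=-1)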


def favor_attention(query,
                    key,
                    value,
                    kernel_transformation,
                    causal,
                    projection_matrix=None):
    """Computes FAVOR normalized attention.

    Args:
      query: query tensor.
      key: key tensor.
      value: value tensor.
      kernel_transformation: transformation used to get finite kernel features.
      causal: whether attention is causal or not.
      projection_matrix: projection matrix to be used.

    Returns:
      FAVOR normalized attention.
    """
    query_prime = kernel_transformation(query, True,
                                        projection_matrix)  # [B,L,H,M]
    key_prime = kernel_transformation(key, False, projection_matrix)  # [B,L,H,M]
    query_prime = tf.transpose(query_prime, [1, 0, 2, 3])  # [L,B,H,M]
    key_prime = tf.transpose(key_prime, [1, 0, 2, 3])  # [L,B,H,M]
    value = tf.transpose(value, [1, 0, 2, 3])  # [L,B,H,D]

    if causal:
        av_attention = causal_numerator(query_prime, key_prime, value)
        attention_normalizer = causal_denominator(query_prime, key_prime)
    else:
        av_attention = noncausal_numerator(query_prime, key_prime, value)
        attention_normalizer = noncausal_denominator(query_prime, key_prime)
    # Transpose the numerator back to [B,L,H,D] and the normalizer to [B,L,H],
    # then broadcast the division over the feature dimension.
    av_attention = tf.transpose(av_attention, [1, 0, 2, 3])
    attention_normalizer = tf.transpose(attention_normalizer, [1, 0, 2])
    attention_normalizer = tf.expand_dims(attention_normalizer,
                                          len(attention_normalizer.shape))
    return av_attention / attention_normalizer
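

# Example sketch (illustrative shapes, assuming eager execution):
#   query = tf.random.normal([2, 10, 4, 64])    # [B, L, H, D]
#   key = tf.random.normal([2, 10, 4, 64])
#   value = tf.random.normal([2, 10, 4, 64])
#   projection = create_projection_matrix(m=128, d=64)
#   out = favor_attention(query, key, value, softmax_kernel_transformation,
#                         causal=False, projection_matrix=projection)
#   # out has shape [2, 10, 4, 64], i.e. [B, L, H, D].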


class Attention(tf.keras.layers.Layer):
    """Multi-headed attention layer."""

    def __init__(self,
                 hidden_size,
                 num_heads,
                 attention_dropout,
                 kernel_transformation=relu_kernel_transformation,
                 numerical_stabilizer=0.001,
                 causal=False,
                 projection_matrix_type=None,
                 nb_random_features=0):
        """Initialize Attention.

        Args:
          hidden_size: int, output dim of hidden layer.
          num_heads: int, number of heads to repeat the same attention structure.
          attention_dropout: float, dropout rate inside attention for training.
          kernel_transformation: transformation used to produce kernel features for
            attention.
          numerical_stabilizer: used to bound away from zero kernel values.
          causal: whether attention is causal or not.
          projection_matrix_type: None if the identity matrix should be used,
            otherwise a random projection matrix will be applied.
          nb_random_features: number of random features to be used (relevant only
            if projection_matrix_type is not None).
        """
        if hidden_size % num_heads:
            raise ValueError(
                "Hidden size ({}) must be divisible by the number of heads ({})."
                .format(hidden_size, num_heads))

        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.kernel_transformation = kernel_transformation
        self.numerical_stabilizer = numerical_stabilizer
        self.causal = causal
        self.projection_matrix_type = projection_matrix_type
        self.nb_random_features = nb_random_features

    def build(self, input_shape):
        """Builds the layer."""
        # Layers for linearly projecting the queries, keys, and values.
        size_per_head = self.hidden_size // self.num_heads

        def _glorot_initializer(fan_in, fan_out):
            limit = math.sqrt(6.0 / (fan_in + fan_out))
            return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)

        attention_initializer = _glorot_initializer(input_shape.as_list()[-1],
                                                    self.hidden_size)
        self.query_dense_layer = util.DenseEinsum(
            output_shape=(self.num_heads, size_per_head),
            kernel_initializer=attention_initializer,
            use_bias=False,
            name="query")
        self.key_dense_layer = util.DenseEinsum(
            output_shape=(self.num_heads, size_per_head),
            kernel_initializer=attention_initializer,
            use_bias=False,
            name="key")
        self.value_dense_layer = util.DenseEinsum(
            output_shape=(self.num_heads, size_per_head),
            kernel_initializer=attention_initializer,
            use_bias=False,
            name="value")

        output_initializer = _glorot_initializer(self.hidden_size, self.hidden_size)
        self.output_dense_layer = util.DenseEinsum(
            output_shape=self.hidden_size,
            num_summed_dimensions=2,
            kernel_initializer=output_initializer,
            use_bias=False,
            name="output_transform")
        super(Attention, self).build(input_shape)

    def get_config(self):
        return {
            "hidden_size": self.hidden_size,
            "num_heads": self.num_heads,
            "attention_dropout": self.attention_dropout,
        }

    def call(self,
             query_input,
             source_input,
             bias,
             training,
             cache=None,
             decode_loop_step=None):
        """Apply attention mechanism to query_input and source_input.

        Args:
          query_input: A tensor with shape [batch_size, length_query, hidden_size].
          source_input: A tensor with shape [batch_size, length_source,
            hidden_size].
          bias: A tensor with shape [batch_size, 1, length_query, length_source],
            the attention bias that standard dot-product attention would add to the
            logits. It is not used by FAVOR attention.
          training: A bool, whether in training mode or not.
          cache: (Used during prediction) A dictionary with tensors containing
            results of previous attentions. The dictionary must have the items:
                {"k": tensor with shape [batch_size, i, heads, dim_per_head],
                 "v": tensor with shape [batch_size, i, heads, dim_per_head]} where
                   i is the current decoded length for non-padded decode, or max
                   sequence length for padded decode.
          decode_loop_step: An integer, step number of the decoding loop. Used only
            for autoregressive inference on TPU.

        Returns:
          Attention layer output with shape [batch_size, length_query, hidden_size]
        """
        # Linearly project the query, key and value using different learned
        # projections. Splitting heads is automatically done during the linear
        # projections --> [batch_size, length, num_heads, dim_per_head].
        query = self.query_dense_layer(query_input)
        key = self.key_dense_layer(source_input)
        value = self.value_dense_layer(source_input)

        if self.projection_matrix_type is None:
            projection_matrix = None
        else:
            dim = query.shape[-1]
            seed = tf.math.ceil(tf.math.abs(tf.math.reduce_sum(query) * BIG_CONSTANT))
            seed = tf.dtypes.cast(seed, tf.int32)
            projection_matrix = create_projection_matrix(
                self.nb_random_features, dim, seed=seed)

        if cache is not None:
            # Combine cached keys and values with new keys and values.
            if decode_loop_step is not None:
                cache_k_shape = cache["k"].shape.as_list()
                indices = tf.reshape(
                    tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
                    [1, cache_k_shape[1], 1, 1])
                key = cache["k"] + key * indices
                cache_v_shape = cache["v"].shape.as_list()
                indices = tf.reshape(
                    tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
                    [1, cache_v_shape[1], 1, 1])
                value = cache["v"] + value * indices
            else:
                key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
                value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)

            # Update cache
            cache["k"] = key
            cache["v"] = value

        attention_output = favor_attention(query, key, value,
                                           self.kernel_transformation, self.causal,
                                           projection_matrix)
        attention_output = self.output_dense_layer(attention_output)
        return attention_output


class SelfAttention(Attention):
    """Multiheaded self-attention layer."""

    def call(self,
             query_input,
             bias,
             training,
             cache=None,
             decode_loop_step=None):
        return super(SelfAttention, self).call(query_input, query_input, bias,
                                               training, cache, decode_loop_step)
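

# Example sketch (not executed on import; hyperparameters are illustrative and
# util.DenseEinsum is assumed to be available in this repository):
#   layer = SelfAttention(hidden_size=256, num_heads=4, attention_dropout=0.1,
#                         kernel_transformation=softmax_kernel_transformation,
#                         causal=True, projection_matrix_type=True,
#                         nb_random_features=128)
#   x = tf.random.normal([2, 10, 256])       # [batch_size, length, hidden_size]
#   y = layer(x, bias=None, training=False)  # [2, 10, 256]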
